#include <torch/csrc/jit/backends/backend_init.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_resolver.h>
#include <torch/csrc/jit/frontend/code_template.h>
#include <torch/csrc/jit/python/pybind_utils.h>

namespace torch {
namespace jit {

// Installs a backend registration callback that, for each backend name it
// receives, defines a Python binding `_jit_to_<name>` on `module`.
//
// The generated binding takes a Module and a Python dict
// (method_compile_spec) and returns a new "LoweredModule" Module that:
//   * holds a clone of the original module, the compile spec, a backend
//     instance and the handles returned by backend.compile as attributes;
//   * exposes one method per key of method_compile_spec, each of which
//     forwards its arguments to backend.execute.
//
// `module` is the Python module object the bindings are added to. It is
// captured by value (as a raw pointer) by the callback.
void initJitBackendBindings(PyObject* module) {
  // Register a backend registration callback that binds a function for
  // lowering to each backend. For example, if there is a registered backend
  // called "example_backend", this callback generates a binding called
  // torch._C._jit_to_example_backend.
  detail::addBackendRegistrationCallback([=](const std::string& name) {
    auto m = py::handle(module).cast<py::module>();
    std::string binding_name("_jit_to_");
    binding_name.append(name);
    m.def(
        binding_name.c_str(),
        [=](const Module& orig_module, const py::dict& method_compile_spec) {
          const c10::QualifiedName qual_backend_name(
              {"__torch__",
               "torch",
               "classes",
               detail::kBackendsNamespace,
               name});
          const std::string& backend_name = qual_backend_name.name();
          // TODO: Validate method_compile_spec.

          // Clone orig_module to make sure backend transformation is
          // functional.
          auto cloned_module = orig_module.clone();

          // Represents a Type of Dict[str, Any].
          auto any_dict_ty =
              DictType::create(StringType::get(), AnyType::get());

          // Generate LoweredModule.
          Module loweredModule(
              "torch.jit." + backend_name + "LoweredModule",
              get_python_cu(),
              /*shouldMangle=*/true);

          // Generate attributes.
          // This is the original cloned and preprocessed module.
          loweredModule.register_attribute(
              "__processed_module",
              AnyType::get(),
              cloned_module._ivalue(),
              /*is_param=*/false);

          // This is for the method_compile_spec passed in to to_<backend> or
          // loaded from an exported model.
          loweredModule.register_attribute(
              "__method_compile_spec",
              any_dict_ty,
              toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
              /*is_param=*/false);

          // This is a pointer to a backend instance that is used to access
          // compile and execute functions. It is registered as a
          // default-constructed (null) pointer here; the generated
          // __create_backend method assigns the real instance.
          auto cls = getCustomClass(qual_backend_name.qualifiedName());
          TORCH_INTERNAL_ASSERT(cls);
          c10::intrusive_ptr<torch::CustomClassHolder> backend;
          loweredModule.register_attribute(
              "__backend", cls, IValue::make_capsule(backend));

          // This is the list of opaque backend handles returned by
          // backend.compile.
          loweredModule.register_attribute(
              "__handles",
              any_dict_ty,
              c10::impl::GenericDict(
                  any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
              /*is_param=*/false);

          // Methods.
          // NOTE: the order of the define() calls below matters; later
          // methods call earlier ones (e.g. __setstate__ and __preprocess
          // both call __create_backend).

          // This is a helper function for creating a new instance of the
          // backend class.
          static const auto create_backend_ct = CodeTemplate(R"(
            def __create_backend(self):
                self.__backend = $name()
            )");
          TemplateEnv create_backend_te;
          create_backend_te.s("name", qual_backend_name.qualifiedName());
          loweredModule.define(
              create_backend_ct.format(create_backend_te),
              loweredModuleResolver());

          // getstate and setstate are for serialization/deserialization of
          // the LoweredModule.
          loweredModule.define(
              R"(
            def __getstate__(self):
                return self.__method_compile_spec, self.__processed_module
            )",
              loweredModuleResolver());

          loweredModule.define(
              R"(
            def __setstate__(self, state):
                self.__method_compile_spec = state[0]
                self.__processed_module = state[1]
                self.__create_backend()
                self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
            )",
              loweredModuleResolver());

          // This is never called during compilation or execution, but is
          // needed to generate the LoweredModule because we don't have access
          // to an instance of the backend as a C++ object with which to call
          // preprocess.
          loweredModule.define(
              R"(
            def __preprocess(self, mod: Any, method_compile_spec: Dict[str, Any]):
                self.__create_backend()
                self.__processed_module = self.__backend.preprocess(mod, method_compile_spec)
          )",
              loweredModuleResolver());

          // This loop generates one method on the LoweredModule for every key
          // in method_compile_spec.
          for (auto& e : method_compile_spec) {
            std::string method_name = py::cast<std::string>(e.first);
            static const auto method_ct = CodeTemplate(R"(
            def $method(self${,def_inputs}):
                typed_inputs: List[Any] = [${fwd_inputs,}]
                $ret, = self.__backend.execute(self.__handles["$method"], typed_inputs)
                ${refine,}
                return $ret
            )");

            TemplateEnv method_te;
            method_te.s("method", method_name);
            auto method = orig_module.get_method(method_name);
            auto& function = method.function();
            auto& schema = function.getSchema();

            // Generate the inputs for the function signature (def_inputs) and
            // for passing to backend.execute (fwd_inputs).
            std::vector<std::string> def_inputs, fwd_inputs;
            for (const auto& arg : schema.arguments()) {
              auto arg_name = arg.name();

              // Skip self since that is only and always present in the
              // signature.
              if (arg_name == "self") {
                continue;
              }

              auto default_value = arg.default_value();

              if (arg.kwarg_only()) {
                // A kwarg is emitted as keyword=default in the definition.
                // In the forwarded inputs it must be the bare identifier:
                // fwd_inputs is spliced into a list literal
                // (`[${fwd_inputs,}]`), where `keyword=keyword` would not be
                // a valid expression.
                TORCH_INTERNAL_ASSERT(default_value.has_value());
                std::stringstream def_ss;
                def_ss << arg_name << "=";
                default_value->repr(
                    def_ss,
                    [](std::ostream&, const IValue&) -> bool { return false; });
                def_inputs.emplace_back(def_ss.str());
                fwd_inputs.emplace_back(arg_name);
              } else {
                // If this is not a kwarg, it should be emitted as is in the
                // signature and in the list passed to backend.execute.
                def_inputs.emplace_back(arg_name);
                fwd_inputs.emplace_back(arg_name);
              }
            }

            // Generate a comma-delimited list of identifiers to unpack
            // outputs, as well as a list of isinstance checks to make sure
            // the backend returned the types it was supposed to.
            std::stringstream out_ss, type_check_ss;
            std::vector<std::string> type_checks;
            TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
            auto out_ty = schema.returns().at(0).type();

            out_ss << "_0";
            type_check_ss << "assert isinstance(_0, ";

            if (auto out_tuple_ty = out_ty->cast<TupleType>()) {
              auto tuple_elements = out_tuple_ty->elements();
              // The generated wrapper unpacks at least one value (`_0`), so a
              // method returning an empty tuple cannot be lowered. Fail with
              // a clear message instead of indexing past the end of
              // tuple_elements.
              TORCH_CHECK(
                  !tuple_elements.empty(),
                  "Method '",
                  method_name,
                  "' returns an empty tuple, which is not supported when ",
                  "lowering to backend '",
                  name,
                  "'");
              type_check_ss << tuple_elements[0]->str() << ")";
              type_checks.emplace_back(type_check_ss.str());
              for (size_t i = 1, num_elements = tuple_elements.size();
                   i < num_elements;
                   ++i) {
                // Reset the stream (contents and state flags) so it can be
                // reused for the next check.
                type_check_ss.str(std::string());
                type_check_ss.clear();
                out_ss << ", _" << i;
                type_check_ss << "assert isinstance(_" << i << ", "
                              << tuple_elements[i]->str() << ")";
                type_checks.emplace_back(type_check_ss.str());
              }
            } else {
              type_check_ss << out_ty->str() << ")";
              type_checks.emplace_back(type_check_ss.str());
            }

            method_te.v("def_inputs", def_inputs);
            method_te.v("fwd_inputs", fwd_inputs);
            method_te.v("refine", type_checks);
            method_te.s("ret", out_ss.str());

            loweredModule.define(
                method_ct.format(method_te), loweredModuleResolver());
          }

          // Run preprocess so that __processed_module is set correctly before
          // compilation.
          loweredModule.run_method(
              "__preprocess",
              cloned_module._ivalue(),
              toIValue(method_compile_spec, any_dict_ty).toGenericDict());

          // Call __setstate__ to ensure that the returned Module is ready to
          // run.
          auto state = at::ivalue::Tuple::create(
              toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
              loweredModule.attr("__processed_module"));
          loweredModule.run_method("__setstate__", state);
          return loweredModule;
        });
  });
}
} // namespace jit
} // namespace torch
